library(grf)
library(ggplot2)

Causal Random Forest
Introduction
Background to causal forests based on:
- Athey, Susan, Julie Tibshirani, and Stefan Wager. 2019. “Generalized Random Forests.” The Annals of Statistics 47 (2): 1148–78. https://doi.org/10.1214/18-AOS1709.
- Wager, Stefan, and Susan Athey. “Estimation and inference of heterogeneous treatment effects using random forests.” Journal of the American Statistical Association 113.523 (2018): 1228-1242.
We will use the grf example to demonstrate this: https://grf-labs.github.io/grf/articles/grf_guide.html
Background
In causal analysis, we aim to estimate the causal effect \(\tau\) of a treatment \(W\). If data come from a randomized controlled trial, we can assume there are no confounders, and the effect is just:
\[ \tau = E[Y_i(1) - Y_i(0)] \]
In observational studies, we have confounders (\(X_i\)), where we need to account for their effect on both \(Y_i\) and \(W_i\) for any individual \(i\). The effect can now be estimated from the following regression:
\[ Y_i = \tau W_i + \beta X_i + \epsilon_i \]
where \(\hat{\tau}\) is taken to be a good estimate of \(\tau\). This relies on the following assumptions:
- Conditional unconfoundedness. i.e. \(W_i\) is unconfounded given \(X_i\)
- \(\{Y_i(1), Y_i(0)\} \perp W_i \mid X_i\)
- The error is assumed random (conditional on \(W_i\) and \(X_i\))
- \(E[\epsilon_i|X_i,W_i] = 0\)
- The confounders have a linear effect on \(Y_i\)
- The treatment effect is constant
We can’t do anything about assumptions 1 and 2 as they are necessary to identify the model. But assumptions 3 and 4 relate to the model used and can be questioned.
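As a minimal sketch of the constant-effect model above (simulated data with made-up coefficients; all names here are illustrative), OLS with the confounder included recovers \(\tau\):

```r
set.seed(1)
n <- 1000
X <- rnorm(n)                    # a single confounder
W <- rbinom(n, 1, plogis(X))     # treatment probability depends on X
Y <- 2 * W + 1.5 * X + rnorm(n)  # true tau = 2, linear confounding

fit <- lm(Y ~ W + X)             # adjust for the confounder
coef(fit)["W"]                   # estimate of tau, close to 2
```

Omitting \(X\) from the regression would bias the estimate upward here, since \(X\) raises both the treatment probability and the outcome.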
Non-linear effects
For assumption 3, we can relax this through a standard semi-parametric approach:
\[ Y_i = \tau W_i + f(X_i) + \epsilon_i \]
Here, the baseline outcome for individual \(i\) is some unknown function of \(X_i\), which can be complex. The treatment is still constant and shifts the baseline estimate by \(\tau\).
The question is then how to define \(f(\cdot)\), given that it is unknown. To do so, grf takes advantage of residual-on-residual regression (the Frisch-Waugh-Lovell approach); Robinson (1988) showed that this can be used with semi-parametric models. In grf terminology, we define two new objects:
- The propensity score: \(e(x) = E(W_i|X_i=x)\)
- The conditional mean of \(Y\): \(m(x) = E(Y_i|X_i=x) = f(x) + \tau e(x)\)
This can then be rewritten as:
\[ Y_i - m(X_i) = \tau (W_i - e(X_i)) + \epsilon_i \]
Robinson describes this as ‘centering’: plug-in estimates of \(m(x)\) and \(e(x)\) are obtained, \(Y_i\) and \(W_i\) are centered, and then the residuals are regressed on one another.
In standard residual-on-residual approaches, we assume that the estimates of \(m(x)\) and \(e(x)\) are obtained through parametric means (e.g. OLS regression). These are replaced by machine learning models in double machine learning (DML) approaches, including causal RFs. Note that these plug-in estimates are obtained using ‘cross-fitting’. In this, the prediction of, for example, the outcome \(m(x)\) based on the confounders \(X_i\) for individual \(i\) is made with a model trained on all observations except \(i\). This avoids bias due to the (different) regularization strategies employed by the two models.
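The centering logic can be verified directly with OLS before moving to machine-learning plug-ins. In this base-R sketch (simulated data), the Frisch-Waugh-Lovell theorem guarantees that the residual-on-residual slope equals the coefficient on \(W\) from the full regression:

```r
set.seed(1)
n <- 500
X <- rnorm(n)
W <- 0.5 * X + rnorm(n)          # treatment confounded by X
Y <- 2 * W + 3 * X + rnorm(n)    # true tau = 2

# Full regression: coefficient on W
tau_full <- coef(lm(Y ~ W + X))["W"]

# Centering: residualize Y and W on X, then regress residual on residual
Y_res <- resid(lm(Y ~ X))
W_res <- resid(lm(W ~ X))
tau_center <- coef(lm(Y_res ~ W_res))["W_res"]

all.equal(unname(tau_full), unname(tau_center))  # TRUE: identical estimates
```

In the DML version, the linear projections `lm(Y ~ X)` and `lm(W ~ X)` are replaced by cross-fitted machine-learning predictions of \(m(x)\) and \(e(x)\).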
Non-constant treatment effects
In the original equation, \(\tau\) is assumed to be constant across all individuals. To relax this, the idea is to find subgroups within the data set, each of which gets its own regression, giving a value of \(\tau\) for each group (\(\tau\) here is still the coefficient of a linear model based on the centered outcome and treatment). The equation now becomes:
\[ Y_i = \tau(X_i) W_i + f(X_i) + \epsilon_i \]
where \(\tau (X_i)\) is the conditional average treatment effect for a given set of values of \(X_i\):
\[ E(Y_i(1) - Y_i(0)|X_i = x) \]
The next question is how to find these groups. We want to find subgroups where \(\tau\) can be assumed constant; in other words, we want to find sets of observations over which we can calculate the residual-on-residual regression. Note that this is still a linear model, where the slope coefficient gives \(\tau\) for that set of observations:
\[ \hat{\tau}(x) = \mathrm{lm}\left( Y_i - \hat{m}^{-i}(X_i) \sim W_i - \hat{e}^{-i}(X_i), \; \mathrm{weights} = \mathbf{1}\{X_i \in \mathcal{N}(x)\} \right) \]
where \(\mathcal{N}(x)\) is the neighborhood (set of training observations) around \(x\), and \(\hat{m}^{-i}\) and \(\hat{e}^{-i}\) are the cross-fitted plug-in estimates, fit without observation \(i\).
A better way of thinking about this may be that we are trying to build a random forest where the outcome is the slope of a regression line. The loss function prioritizes the biggest difference in slope at any split point.
CF Algorithm
1. Fit the first-stage (nuisance) models for the outcome and the propensity using any standard machine learning algorithm (cross-fit: use different subsets of the data for the two models)
2. Use the first-stage models to estimate values for the outcome (\(m(x)\)) and treatment (\(e(x)\))
3. Calculate outcome residuals (\(Y' = Y_i - m(x)\)) and treatment residuals (\(T' = T_i - e(x)\))
4. Fit the second-stage model (causal forest). For each tree:
    i. Bootstrap the data into in-bag (IB) and out-of-bag (OOB) sets
    ii. Using the IB set, for each feature (\(X_j\)):
        a. Iterate across values of \(X_j\) to partition the IB data into two sets (\(L\) and \(R\))
        b. Test for imbalance (hyperparameter: a 25-75% split is the maximum imbalance by default): skip the split if the imbalance is greater than this
        c. Fit the following model in each of the two partitions:
            - Left partition (\(L\)): \(Y'_L = \tau_L T'_L + e\)
            - Right partition (\(R\)): \(Y'_R = \tau_R T'_R + e\)
        d. Calculate the difference in treatment effect: \(\delta \tau = |\tau_L - \tau_R|\)
    iii. Repeat for all \(j\) features to find the feature (and value) that maximizes \(\delta \tau\)
    iv. Make new data subsets for both IB and OOB based on this split
    v. Repeat from ii. All subsequent steps operate on \(> 1\) data subsets, so splits of \(X_j\) must be tested across all existing partitions
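The split-scoring step above can be sketched in a few lines of base R: given a centered outcome and treatment, a candidate threshold on one feature is scored by the absolute difference in residual-on-residual slopes on either side. (This is illustrative only: grf scores splits with a gradient-based approximation rather than refitting a regression at every candidate value.)

```r
set.seed(2)
n <- 1000
x <- rnorm(n)                    # the feature being split on
w_res <- rnorm(n)                # centered treatment, W - e(x)
tau <- ifelse(x > 0, 2, 0)       # treatment effect jumps at x = 0
y_res <- tau * w_res + rnorm(n)  # centered outcome, Y - m(x)

# Score a candidate split: |tau_L - tau_R| from the two within-partition slopes
score_split <- function(cut) {
  L <- x < cut
  tau_L <- coef(lm(y_res[L] ~ w_res[L]))[2]
  tau_R <- coef(lm(y_res[!L] ~ w_res[!L]))[2]
  abs(tau_L - tau_R)
}

cuts <- seq(-1.5, 1.5, by = 0.1)
best <- cuts[which.max(sapply(cuts, score_split))]
best                             # best cut, near the true break at x = 0
```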
Example 1
Example of fitting a causal forest with a nonlinear relationship between X and \(\tau\):
Create some data (X[,1] is a confounder; X[,2] and X[,3] have a non-linear impact on the outcome)
set.seed(42)
n <- 2000
p <- 10
X <- matrix(rnorm(n * p), n, p)
X_test <- matrix(0, 101, p)
X_test[, 1] <- seq(-2, 2, length.out = 101)
W <- rbinom(n, 1, 0.4 + 0.2 * (X[, 1] > 0))
Y <- pmax(X[, 1], 0) * W + X[, 2] + pmin(X[, 3], 0) + rnorm(n)

Plot X_1 and Y
plot_df = data.frame(X = X[, 1], W = as.factor(W), Y = Y)
ggplot(plot_df, aes(x = X, y = Y, col = W)) +
  geom_point() +
  geom_smooth()
Fit causal forest
tau_forest <- causal_forest(X, Y, W)
tau_forest

GRF forest object of type causal_forest
Number of trees: 2000
Number of training samples: 2000
Variable importance:
1 2 3 4 5 6 7 8 9 10
0.704 0.037 0.034 0.042 0.033 0.030 0.028 0.025 0.032 0.036
Predict and plot
tau_hat <- predict(tau_forest, X_test)
plot_df = data.frame(X = rep(X_test[,1], 2),
tau = c(pmax(0, X_test[, 1]), tau_hat$predictions),
label = rep(c("Truth","Pred"), each = nrow(X_test)))
ggplot(plot_df, aes(x = X, y = tau, col = label)) +
geom_line() +
theme_bw()

Example 1b
This uses the same data as before, but walks through a single split in a causal tree.
Step 1: nuisance model for treatment (the propensity model)
forest_W <- regression_forest(X, W, tune.parameters = "all")
W_hat <- predict(forest_W)$predictions

Step 2: nuisance model for outcome
forest_Y <- regression_forest(X, Y, tune.parameters = "all")
Y_hat <- predict(forest_Y)$predictions

Step 2b (optional): variable selection for forest
forest_Y_varimp <- variable_importance(forest_Y)
forest_Y_varimp

             [,1]
[1,] 0.064545274
[2,] 0.722531655
[3,] 0.179488680
[4,] 0.003975335
[5,] 0.004832954
[6,] 0.005150976
[7,] 0.004071005
[8,] 0.005604137
[9,] 0.005134106
[10,] 0.004665877
tau_forest <- causal_forest(X, Y, W,
W.hat = W_hat, Y.hat = Y_hat,
tune.parameters = "all")
tau_hat <- predict(tau_forest, X_test)
plot(X_test[, 1], tau_hat$predictions, ylim = range(tau_hat$predictions, 0, 2), xlab = "x", ylab = "tau", type = "l")
lines(X_test[, 1], pmax(0, X_test[, 1]), col = 2, lty = 2)

library(animation)
breaks = seq(-2,2,by = 0.1)
nbreaks = length(breaks)
tau_df = data.frame(brks = breaks,
tau1 = rep(NA, nbreaks),
tau2 = rep(NA, nbreaks),
dtau = rep(NA, nbreaks))
plot_df = data.frame(X = X[,1],
W = as.factor(W),
W_hat = W_hat,
Y = Y,
Y_hat = Y_hat)
plot_df$tau_hat <- predict(tau_forest)$predictions

Output at: test.gif
[1] TRUE
Pseudo-code for the full tree:
nsplit = 10
out_df = data.frame(split = 1:nsplit,
X = rep(NA, nsplit),
tau_lo = rep(NA, nsplit),
tau_hi = rep(NA, nsplit))
data_list = list(plot_df)
breaks = seq(-2,2,by = 0.1)
nbreaks = length(breaks)
for (i in 1:nsplit) {
max_dtau = max_df = max_x = max_tau1 = max_tau2 = -9999
for (j in 1:length(data_list)) {
tmp_dl = data_list[[j]]
print(paste(i,j))
tau_df = data.frame(brks = breaks,
tau1 = rep(NA, nbreaks),
tau2 = rep(NA, nbreaks),
dtau = rep(NA, nbreaks))
for (k in 1:nbreaks) {
dat1 = tmp_dl |>
dplyr::filter(X < breaks[k])
dat2 = tmp_dl |>
dplyr::filter(X >= breaks[k])
if (nrow(dat1) > 10 & nrow(dat2) > 10) {
## Residual-on-residual regression in each partition:
## outcome residual (Y - Y_hat) on treatment residual (W - W_hat)
mod1 = lm(I(Y - Y_hat) ~ I(as.numeric(as.character(W)) - W_hat), dat1)
tau_df$tau1[k] = coef(mod1)[2]
mod2 = lm(I(Y - Y_hat) ~ I(as.numeric(as.character(W)) - W_hat), dat2)
tau_df$tau2[k] = coef(mod2)[2]
tau_df$dtau[k] = abs(tau_df$tau1[k] - tau_df$tau2[k])
}
}
rowID = which.max(tau_df$dtau)
if (any(!is.na(tau_df$dtau))) {
if (tau_df$dtau[rowID] > max_dtau) {
max_tau1 = tau_df$tau1[rowID]
max_tau2 = tau_df$tau2[rowID]
max_dtau = tau_df$dtau[rowID]
max_x = tau_df$brks[rowID]
max_df = j
}
}
}
print(paste(max_dtau, max_x, max_df, max_tau1, max_tau2))
out_df$X[i] = max_x
out_df$tau_lo[i] = max_tau1
out_df$tau_hi[i] = max_tau2
new_dl = list()
for (j in 1:length(data_list)) {
if (j == max_df) {
dat1 = data_list[[j]] |>
dplyr::filter(X < max_x)
dat2 = data_list[[j]] |>
dplyr::filter(X >= max_x)
new_dl = append(new_dl, list(dat1, dat2))
} else {
new_dl = append(new_dl, list(data_list[[j]]))
}
}
## Reassign data_list
data_list = new_dl
}
Example 2
From grf docs
In this section, we walk through an example application of GRF. The data are from Bruhn et al. (2016), who conducted an RCT in Brazil in which high schools were randomly assigned a financial education program (in settings like this it is common to randomize at the school level to avoid student-level interference). This program increased student financial proficiency on average. Other outcomes are considered in the paper; we’ll focus on the financial proficiency score here. A processed copy of the data, containing records from around 17,000 students, is stored on the GitHub repo. It includes basic student characteristics, as well as additional baseline survey responses used as covariates (two of these are aggregated by the authors into indices measuring a student’s ability to save and their financial autonomy).
library(grf)
data <- read.csv("./data/bruhn2016.csv")
Y <- data$outcome
W <- data$treatment
school <- data$school
X <- data[-(1:3)]

Around 30% of students have one or more missing covariates. The missingness pattern doesn’t appear to vary systematically between treated and controls, so we’ll keep them in the analysis, since GRF supports splitting on X’s with missing values.
sum(!complete.cases(X)) / nrow(X)

[1] 0.2934852
t.test(W ~ !complete.cases(X))
Welch Two Sample t-test
data: W by !complete.cases(X)
t = -0.3923, df = 9490.1, p-value = 0.6948
alternative hypothesis: true difference in means between group FALSE and group TRUE is not equal to 0
95 percent confidence interval:
-0.01963191 0.01308440
sample estimates:
mean in group FALSE mean in group TRUE
0.5131730 0.5164467
Fitting causal forest
cf <- causal_forest(X, Y, W, W.hat = 0.5, clusters = school)
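With the forest fit, a natural next step is to summarize the estimated effects. The snippet below uses simulated stand-in data (the Bruhn file is not bundled here): `average_treatment_effect()` is grf’s doubly robust ATE estimator, and `predict()` without new data returns out-of-bag CATE estimates. The same calls apply to the `cf` object above.

```r
library(grf)
set.seed(3)
n <- 500
X_sim <- matrix(rnorm(n * 4), n, 4)
W_sim <- rbinom(n, 1, 0.5)                # randomized treatment, so W.hat = 0.5
Y_sim <- pmax(X_sim[, 1], 0) * W_sim + X_sim[, 2] + rnorm(n)

cf_sim <- causal_forest(X_sim, Y_sim, W_sim, W.hat = 0.5, num.trees = 500)

average_treatment_effect(cf_sim)          # ATE estimate with standard error
tau_hat_oob <- predict(cf_sim)$predictions  # out-of-bag CATEs, one per student
summary(tau_hat_oob)
```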